{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "886ed769-07f9-474f-9f86-c9f1f17345e4",
   "metadata": {},
   "source": [
    "## 使用微调后的 LLaMA2-7B 推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24f9cf61-2994-48c7-9a42-150920397a93",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1c571fa5-aa51-495c-8b3d-d5605a02e491",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "70a30ffcc759412186e57f3c13521b52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/root/miniconda3/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from peft import AutoPeftModelForCausalLM\n",
    "from transformers import AutoTokenizer, BitsAndBytesConfig\n",
    "\n",
    "\n",
    "model_dir = \"models/llama-7-int4-dolly\"\n",
    "\n",
    "# 4-bit quantization with float16 compute. Without an explicit compute dtype\n",
    "# bitsandbytes defaults to float32 and warns about slow inference when the\n",
    "# model is fed float16 activations (torch_dtype=torch.float16 below).\n",
    "quantization_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_compute_dtype=torch.float16,\n",
    ")\n",
    "\n",
    "# Load the LLM (through PEFT's auto class) and its tokenizer from model_dir.\n",
    "model = AutoPeftModelForCausalLM.from_pretrained(\n",
    "    model_dir,\n",
    "    low_cpu_mem_usage=True,\n",
    "    torch_dtype=torch.float16,\n",
    "    quantization_config=quantization_config,\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7700c042-20cf-4ff1-8c40-a91cabeaaaa6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n",
      "/root/miniconda3/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:226: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.\n",
      "  warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt:\n",
      "Football (or soccer) is the world's most popular sport. Others include cricket, hockey, tennis, volleyball, table tennis, and basketball. This might come as a surprise to Americans, who favor (American) football.\n",
      "\n",
      "Generated instruction:\n",
      "Which is the most popular sport in the world?\n",
      "\n",
      "Ground truth:\n",
      "What are the world's most popular sports?\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from random import randrange\n",
    "\n",
    "\n",
    "# Load the dataset from the hub and pick one random sample.\n",
    "dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")\n",
    "sample = dataset[randrange(len(dataset))]\n",
    "\n",
    "# The sample's response is the input; the model is asked to produce an\n",
    "# instruction that could have generated it.\n",
    "prompt = f\"\"\"### Instruction:\n",
    "Use the Input below to create an instruction, which could have been used to generate the input using an LLM. \n",
    " \n",
    "### Input:\n",
    "{sample['response']}\n",
    " \n",
    "### Response:\n",
    "\"\"\"\n",
    "\n",
    "# `truncation=True` is omitted on purpose: with no max_length it does nothing\n",
    "# except emit a warning. The ids are moved to the model's own device instead\n",
    "# of assuming CUDA.\n",
    "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "\n",
    "outputs = model.generate(\n",
    "    input_ids=input_ids,\n",
    "    max_new_tokens=100,\n",
    "    do_sample=True,\n",
    "    top_p=0.9,\n",
    "    temperature=0.9,\n",
    ")\n",
    "\n",
    "# generate() returns the prompt tokens followed by the continuation, so drop\n",
    "# the prompt by token count. Slicing the decoded string by len(prompt) is\n",
    "# unreliable because decoding does not always reproduce the prompt text\n",
    "# character-for-character.\n",
    "generated_ids = outputs[0, input_ids.shape[1]:]\n",
    "generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)\n",
    "\n",
    "print(f\"Prompt:\\n{sample['response']}\\n\")\n",
    "print(f\"Generated instruction:\\n{generated_text}\")\n",
    "print(f\"Ground truth:\\n{sample['instruction']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f6ea5b0-cf32-4cc0-bd5e-fc17268992c0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
